import os
import xai
import logging as log
import warnings
import matplotlib.pyplot as plt
import sys
from util.commons import *
from util.ui import *
from util.model import *
from util.split import *
from util.dataset import *
from IPython.display import display, HTML
In this notebook we analyze a dataset named 'Risk Factors for Cervical Cancer'. The dataset was collected at the 'Hospital Universitario de Caracas' in Caracas, Venezuela, and comprises demographic information, habits, and historic medical records of 858 patients. Several patients decided not to answer some of the questions because of privacy concerns, which results in missing values.
dataset, msg = get_dataset('cervical_cancer')
display(msg)
display(dataset.df)
The dataset will be used in the same way as described at https://christophm.github.io/interpretable-ml-book/cervical.html. All unknown values ('?') will be set to 0.0.
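As a minimal sketch of this step (toy values, not the actual dataset; the notebook itself delegates this to the `normalize_undefined_values` helper), the '?' placeholders can be replaced with 0.0 like so:

```python
import pandas as pd

# Toy frame standing in for the real dataset: one column with a '?' placeholder.
toy = pd.DataFrame({'Smokes (years)': ['1.27', '?', '0.0']})

# Replace the '?' marker with 0.0 and cast the column to float.
toy = toy.replace('?', 0.0).astype(float)
```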
df = dataset.df.drop(columns=['Smokes (packs/year)', 'STDs:condylomatosis', 'STDs:cervical condylomatosis', 'STDs:genital herpes',
                              'STDs:Hepatitis B', 'STDs:vulvo-perineal condylomatosis', 'Dx:HPV',
                              'STDs:molluscum contagiosum', 'STDs:syphilis', 'STDs:AIDS', 'Hinselmann',
                              'STDs:pelvic inflammatory disease', 'STDs:HPV', 'Dx:CIN', 'Dx', 'STDs:HIV',
                              'Schiller', 'STDs:vaginal condylomatosis', 'Dx:Cancer', 'Citology'])
num_cols = ['Number of sexual partners', 'First sexual intercourse', 'Num of pregnancies', 'Smokes',
'Smokes (years)', 'Hormonal Contraceptives', 'Hormonal Contraceptives (years)', 'IUD',
'IUD (years)', 'STDs', 'STDs (number)', 'STDs: Time since first diagnosis',
'STDs: Time since last diagnosis']
df = normalize_undefined_values('?', df)
str_limit = 5
for col in df.columns:
    if col in num_cols and len(df[col].unique()) > str_limit:
        # Many distinct values: treat the column as continuous.
        df[col] = df[col].astype('float')
    elif col in num_cols and len(df[col].unique()) <= str_limit:
        # Few distinct values: treat the column as categorical.
        df[col] = df[col].astype(str)
df
Three visualization functions offered by the xai module will be used to analyze the dataset.
%matplotlib inline
plt.style.use('ggplot')
warnings.filterwarnings('ignore')
imbalanced_cols = ['Biopsy']
xai.imbalance_plot(df, *imbalanced_cols, categorical_cols=['Biopsy'])
xai.correlations(df, include_categorical=True, plot_type="matrix")
xai.correlations(df, include_categorical=True)
In the cell below the target variable is selected. Biopsy serves as the gold standard for diagnosing cervical cancer, so we use it as the target.
df_X, df_y, msg = split_feature_target(df, "Biopsy")
df_y
In this step, three models are trained on the dataset. The output below shows classification reports for the trained models.
# Create three empty models
initial_models, msg = fill_empty_models(df_X, df_y, 3)
models = []
# Train model 1
model1 = initial_models[0]
msg = fill_model(model1, Algorithm.LOGISTIC_REGRESSION, Split(SplitTypes.IMBALANCED, None))
models.append(model1)
# Train model 2
model2 = initial_models[1]
msg = fill_model(model2, Algorithm.RANDOM_FOREST, Split(SplitTypes.IMBALANCED, None))
models.append(model2)
# Train model 3
model3 = initial_models[2]
msg = fill_model(model3, Algorithm.DECISION_TREE, Split(SplitTypes.IMBALANCED, None))
models.append(model3)
model_1 = models[0]
model_2 = models[1]
model_3 = models[2]
In the following steps we will use global interpretation techniques, which help us answer questions such as: How does the model behave in general? Which features drive the predictions, and which features are completely useless? Such insight can be very important for understanding the model better. Most of these techniques work by investigating the conditional interactions between the target variable and the features on the complete dataset.
The importance of a feature is the increase in the prediction error of the model after permuting the feature's values, which breaks the relationship between the feature and the true outcome. A feature is "important" if permuting it increases the model error, because in that case the model relied heavily on the feature for making correct predictions. Conversely, a feature is "unimportant" if permuting it barely changes the error or does not change it at all.
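To make the idea concrete, here is a minimal permutation-importance sketch on synthetic data with a hand-written "model" (all names and data here are illustrative, not part of the notebook's pipeline):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: y depends strongly on x0, weakly on x1, not at all on x2.
X = rng.normal(size=(500, 3))
y = (3 * X[:, 0] + 0.3 * X[:, 1] > 0).astype(int)

def accuracy(X, y):
    # A fixed "model" that mirrors the labeling rule exactly.
    preds = (3 * X[:, 0] + 0.3 * X[:, 1] > 0).astype(int)
    return (preds == y).mean()

baseline = accuracy(X, y)  # 1.0 for this toy model
importances = []
for j in range(X.shape[1]):
    Xp = X.copy()
    Xp[:, j] = rng.permutation(Xp[:, j])  # break the feature/target link
    importances.append(baseline - accuracy(Xp, y))

# importances: large for x0, small for x1, exactly 0 for the unused x2.
```

Permuting x0 destroys most of the accuracy, permuting x1 costs a little, and permuting the unused x2 changes nothing, which is exactly the ranking a permutation-importance method would report.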
First, we use ELI5, which does not permute the features but visualizes the weight assigned to each feature.
plot = generate_feature_importance_plot(FeatureImportanceType.ELI5, model_1)
display(plot)
plot = generate_feature_importance_plot(FeatureImportanceType.ELI5, model_2)
display(plot)
plot = generate_feature_importance_plot(FeatureImportanceType.ELI5, model_3)
display(plot)
print(generate_feature_importance_explanation(FeatureImportanceType.ELI5, models, 4))
%matplotlib inline
plt.rcParams['figure.figsize'] = [14, 15]
plt.style.use('ggplot')
warnings.filterwarnings('ignore')
_ = generate_feature_importance_plot(FeatureImportanceType.SKATER, model_1)
_ = generate_feature_importance_plot(FeatureImportanceType.SKATER, model_2)
_ = generate_feature_importance_plot(FeatureImportanceType.SKATER, model_3)
print('\n' + generate_feature_importance_explanation(FeatureImportanceType.SKATER, models, 4))
In the cells below we use SHAP (SHapley Additive exPlanations). SHAP combines feature contributions with game theory to compute SHAP values, and then derives the global feature importance by averaging the magnitudes of the SHAP values across the dataset.
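The aggregation step can be sketched as follows, with a made-up attribution matrix standing in for actual SHAP values:

```python
import numpy as np

# Hypothetical per-instance attributions (rows = instances, columns =
# features), standing in for a matrix of SHAP values.
shap_values = np.array([
    [ 0.40, -0.10, 0.02],
    [-0.35,  0.20, 0.01],
    [ 0.50, -0.15, 0.03],
])

# Global importance = mean absolute attribution per feature.
global_importance = np.abs(shap_values).mean(axis=0)
```

Taking absolute values before averaging matters: feature 0 pushes predictions in both directions, so a plain mean would understate its influence.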
from shap import initjs
initjs()
%matplotlib inline
plt.style.use('ggplot')
warnings.filterwarnings('ignore')
generate_feature_importance_plot(FeatureImportanceType.SHAP, model_1)
generate_feature_importance_plot(FeatureImportanceType.SHAP, model_2)
generate_feature_importance_plot(FeatureImportanceType.SHAP, model_3)
print(generate_feature_importance_explanation(FeatureImportanceType.SHAP, models, 4))
The partial dependence plot (short PDP or PD plot) shows the marginal effect one or two features have on the predicted outcome of a machine learning model. A partial dependence plot can show whether the relationship between the target and a feature is linear, monotonic or more complex. For example, when applied to a linear regression model, partial dependence plots always show a linear relationship.
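The PD computation itself is simple: fix the feature of interest at a grid value for every row, predict, and average. A minimal sketch with a toy model (illustrative names and data, not the notebook's models):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.uniform(0, 1, size=(200, 2))

def model(X):
    # Toy model that is linear in x0, so the PDP for x0 is a straight line.
    return 2.0 * X[:, 0] + np.sin(5 * X[:, 1])

grid = np.linspace(0, 1, 11)
pd_values = []
for v in grid:
    Xv = X.copy()
    Xv[:, 0] = v                        # fix the feature of interest
    pd_values.append(model(Xv).mean())  # average predictions over the data

# Consecutive PD values differ by a constant 2 * 0.1 = 0.2: a linear PDP.
```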
PDPBox is the first module that we use for plotting partial dependence.
generate_pdp_plots(PDPType.PDPBox, model_1, "Age", "None")
generate_pdp_plots(PDPType.PDPBox, model_1, "Age", "Number of sexual partners")
generate_pdp_plots(PDPType.PDPBox, model_2, "Age", "None")
generate_pdp_plots(PDPType.PDPBox, model_2, "Age", "Number of sexual partners")
generate_pdp_plots(PDPType.PDPBox, model_3, "Age", "None")
generate_pdp_plots(PDPType.PDPBox, model_3, "Age", "Number of sexual partners")
generate_pdp_plots(PDPType.SKATER, model_1, "Age", "Number of sexual partners")
generate_pdp_plots(PDPType.SKATER, model_2, "Age", "Number of sexual partners")
generate_pdp_plots(PDPType.SKATER, model_3, "Age", "Number of sexual partners")
generate_pdp_plots(PDPType.SHAP, model_1, "Age", "Number of sexual partners")
generate_pdp_plots(PDPType.SHAP, model_2, "Age", "Number of sexual partners")
generate_pdp_plots(PDPType.SHAP, model_3, "Age", "Number of sexual partners")
Local interpretation focuses on the specifics of each individual instance and provides explanations that can lead to a better understanding of the feature contributions in smaller groups of individuals, which are often overlooked by global interpretation techniques. We will use two modules for interpreting single instances: SHAP and LIME.
SHAP leverages the idea of Shapley values for model feature influence scoring. The technical definition of a Shapley value is the “average marginal contribution of a feature value over all possible coalitions.” In other words, Shapley values consider all possible predictions for an instance using all possible combinations of inputs. Because of this exhaustive approach, SHAP can guarantee properties like consistency and local accuracy. LIME, on the other hand, does not offer such guarantees.
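The "average marginal contribution over all possible coalitions" can be computed exactly for a tiny cooperative game; the payoffs below are made up purely for illustration:

```python
from itertools import permutations

# A toy cooperative game on three "features": v maps each coalition to its
# payoff (the values are hypothetical).
v = {
    frozenset(): 0,
    frozenset('a'): 10, frozenset('b'): 20, frozenset('c'): 30,
    frozenset('ab'): 40, frozenset('ac'): 50, frozenset('bc'): 60,
    frozenset('abc'): 90,
}

def shapley(player, players='abc'):
    # Average the player's marginal contribution over all orderings.
    perms = list(permutations(players))
    total = 0
    for order in perms:
        before = frozenset(order[:order.index(player)])
        total += v[before | {player}] - v[before]
    return total / len(perms)

# The Shapley values of 'a', 'b', 'c' sum to v({'a', 'b', 'c'}) = 90.
```

The sum-to-total property is the "local accuracy" guarantee mentioned above: the attributions always add up exactly to the value of the grand coalition (in SHAP terms, to the model's prediction minus the base value).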
LIME (Local Interpretable Model-agnostic Explanations) builds sparse linear models around each prediction to explain how the black-box model works in that local vicinity. While treating the model as a black box, we perturb the instance we want to explain and learn a sparse linear model around it as an explanation. LIME's advantage over SHAP is that it is much faster.
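The perturb-weight-fit recipe behind LIME can be sketched in a few lines (a simplified illustration with made-up functions and data, not the lime library's actual implementation):

```python
import numpy as np

rng = np.random.default_rng(2)

def black_box(X):
    # A nonlinear "black box" whose local behaviour we want to explain.
    return np.tanh(3 * X[:, 0]) + 0.5 * X[:, 1] ** 2

x0 = np.array([0.1, 0.5])  # instance to explain

# 1. Sample perturbations of the instance.
Z = x0 + rng.normal(scale=0.1, size=(500, 2))
yz = black_box(Z)

# 2. Weight the samples by proximity to x0 (Gaussian kernel).
w = np.exp(-np.sum((Z - x0) ** 2, axis=1) / 0.1)

# 3. Fit a weighted linear surrogate model around x0.
A = np.column_stack([np.ones(len(Z)), Z])
sw = np.sqrt(w)
coef, *_ = np.linalg.lstsq(A * sw[:, None], yz * sw, rcond=None)
# coef[1] and coef[2] approximate the local effect of each feature at x0.
```

The fitted coefficients approximate the black box's local gradient at x0, which is what a LIME explanation reports for each feature (the real library additionally enforces sparsity, e.g. by limiting the number of features in the surrogate).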
examples = get_test_examples(model_1, ExampleType.FALSELY_CLASSIFIED, 2)
examples += get_test_examples(model_2, ExampleType.TRULY_CLASSIFIED, 2)
examples
print(get_example_information(model_1, examples[0]))
print(generate_single_instance_comparison(models, examples[0]))
explanation = explain_single_instance(LocalInterpreterType.LIME, model_1, examples[0])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_1, examples[0]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_1, examples[0])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_1, examples[0]))
display(explanation)
explanation = explain_single_instance(LocalInterpreterType.LIME, model_2, examples[0])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_2, examples[0]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_2, examples[0])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_2, examples[0]))
display(explanation)
explanation = explain_single_instance(LocalInterpreterType.LIME, model_3, examples[0])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_3, examples[0]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_3, examples[0])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_3, examples[0]))
display(explanation)
print(get_example_information(model_1, examples[1]))
print(generate_single_instance_comparison(models, examples[1]))
explanation = explain_single_instance(LocalInterpreterType.LIME, model_1, examples[1])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_1, examples[1]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_1, examples[1])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_1, examples[1]))
display(explanation)
explanation = explain_single_instance(LocalInterpreterType.LIME, model_2, examples[1])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_2, examples[1]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_2, examples[1])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_2, examples[1]))
display(explanation)
explanation = explain_single_instance(LocalInterpreterType.LIME, model_3, examples[1])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_3, examples[1]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_3, examples[1])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_3, examples[1]))
display(explanation)
print(get_example_information(model_1, examples[2]))
print(generate_single_instance_comparison(models, examples[2]))
explanation = explain_single_instance(LocalInterpreterType.LIME, model_1, examples[2])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_1, examples[2]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_1, examples[2])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_1, examples[2]))
display(explanation)
explanation = explain_single_instance(LocalInterpreterType.LIME, model_2, examples[2])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_2, examples[2]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_2, examples[2])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_2, examples[2]))
display(explanation)
explanation = explain_single_instance(LocalInterpreterType.LIME, model_3, examples[2])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_3, examples[2]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_3, examples[2])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_3, examples[2]))
display(explanation)
print(get_example_information(model_1, examples[3]))
print(generate_single_instance_comparison(models, examples[3]))
explanation = explain_single_instance(LocalInterpreterType.LIME, model_1, examples[3])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_1, examples[3]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_1, examples[3])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_1, examples[3]))
display(explanation)
explanation = explain_single_instance(LocalInterpreterType.LIME, model_2, examples[3])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_2, examples[3]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_2, examples[3])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_2, examples[3]))
display(explanation)
explanation = explain_single_instance(LocalInterpreterType.LIME, model_3, examples[3])
print(generate_single_instance_explanation(LocalInterpreterType.LIME, model_3, examples[3]))
explanation.show_in_notebook(show_table=True, show_all=True)
explanation = explain_single_instance(LocalInterpreterType.SHAP, model_3, examples[3])
print(generate_single_instance_explanation(LocalInterpreterType.SHAP, model_3, examples[3]))
display(explanation)